/***************************************************************************
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
* Copyright (c) QuantStack                                                 *
*                                                                          *
* Distributed under the terms of the BSD 3-Clause License.                 *
*                                                                          *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/

// This file is generated from test/files/cppy_source/test_qr.cppy by preprocess.py!

#include <algorithm>

#include "gtest/gtest.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xfixed.hpp"
#include "xtensor/xnoalias.hpp"
#include "xtensor/xstrided_view.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xview.hpp"

#include "xtensor-blas/xlinalg.hpp"

namespace xt
{
    using namespace xt::placeholders;

    /*py
    a = np.random.random((6, 3))
    res_q1 = np.linalg.qr(a, 'raw')
    res_q2 = np.linalg.qr(a, 'complete')
    res_q3 = np.linalg.qr(a, 'reduced')
    res_q4 = np.linalg.qr(a, 'r')
    */
    TEST(xtest_extended, qr1)
    {
        // Checks linalg::qr on a tall 6x3 matrix against the values that
        // numpy.linalg.qr produced for the same input in each of the four modes
        // ('raw', 'complete', 'reduced', 'r'). Every comparison uses allclose
        // with its default tolerances.

        // Expects both members of a two-element result tuple to match the given
        // arrays: std::get<0> against expected_first, std::get<1> against expected_second.
        auto expect_tuple_close = [](const auto& result,
                                     const xarray<double>& expected_first,
                                     const xarray<double>& expected_second)
        {
            EXPECT_TRUE(allclose(std::get<0>(result), expected_first));
            EXPECT_TRUE(allclose(std::get<1>(result), expected_second));
        };

        // numpy: a
        xarray<double> input = {{0.3745401188473625,0.9507143064099162,0.7319939418114051},
                                {0.5986584841970366,0.1560186404424365,0.1559945203362026},
                                {0.0580836121681995,0.8661761457749352,0.6011150117432088},
                                {0.7080725777960455,0.0205844942958024,0.9699098521619943},
                                {0.8324426408004217,0.2123391106782762,0.1818249672071006},
                                {0.1834045098534338,0.3042422429595377,0.5247564316322378}};

        // ---- mode 'raw' ----
        // numpy: res_q1[0]
        xarray<double> raw_h = {{-1.3152987216651169, 0.3542695728401418, 0.0343722790456067,
                                  0.4190178144924799, 0.4926165861757361, 0.1085337284576868},
                                {-0.567877094797874 , 1.2223138676385652,-0.5073775633545011,
                                  0.3838046167052855, 0.3339455785740943,-0.0869071101793681},
                                {-1.0163710885529547, 0.7215655008695085, 0.7854784971183756,
                                 -0.8184018010449023, 0.3355103841692941,-0.2743559826773574}};
        // numpy: res_q1[1]
        xarray<double> raw_tau = {1.2847566964660388,1.3124991842889797,1.0766465015522177};

        expect_tuple_close(linalg::qr(input, linalg::qrmode::raw), raw_h, raw_tau);

        // ---- mode 'complete' ----
        // numpy: res_q2[0]
        xarray<double> complete_q = {{-0.2847566964660388, 0.6455031901264903,-0.0295327810119745,
                                      -0.5849049416686276,-0.0730618203174815,-0.3923203408230155},
                                     {-0.4551502060605353,-0.0838170448559192,-0.3133472182914374,
                                       0.0819245453270295,-0.7892351407115685, 0.2408791714587237},
                                     {-0.0441600156766425, 0.6881200538051699, 0.0760152664601146,
                                       0.7143224973945711, 0.0235700722943727, 0.0891638112668338},
                                     {-0.538335943107778 ,-0.2332659103773061, 0.7525061466150679,
                                       0.1447692100263401,-0.0279639819291248,-0.2603378924852557},
                                     {-0.6328924578795164,-0.1203177215897514,-0.4769214096589269,
                                       0.1040507467269481, 0.5878955555305321, 0.0326957112268427},
                                     {-0.1394394344284399, 0.1841243791750922, 0.31850193596774  ,
                                      -0.3303532438685529, 0.1575155429538277, 0.8433664457979998}};
        // numpy: res_q2[1]
        xarray<double> complete_r = {{-1.3152987216651169,-0.567877094797874 ,-1.0163710885529547},
                                     { 0.                , 1.2223138676385652, 0.7215655008695085},
                                     { 0.                , 0.                , 0.7854784971183756},
                                     { 0.                , 0.                , 0.                },
                                     { 0.                , 0.                , 0.                },
                                     { 0.                , 0.                , 0.                }};

        expect_tuple_close(linalg::qr(input, linalg::qrmode::complete), complete_q, complete_r);

        // ---- mode 'reduced' ----
        // numpy: res_q3[0]
        xarray<double> reduced_q = {{-0.2847566964660388, 0.6455031901264903,-0.0295327810119745},
                                    {-0.4551502060605353,-0.0838170448559192,-0.3133472182914374},
                                    {-0.0441600156766425, 0.6881200538051699, 0.0760152664601146},
                                    {-0.538335943107778 ,-0.2332659103773061, 0.7525061466150679},
                                    {-0.6328924578795164,-0.1203177215897514,-0.4769214096589269},
                                    {-0.1394394344284399, 0.1841243791750922, 0.31850193596774  }};
        // numpy: res_q3[1]
        xarray<double> reduced_r = {{-1.3152987216651169,-0.567877094797874 ,-1.0163710885529547},
                                    { 0.                , 1.2223138676385652, 0.7215655008695085},
                                    { 0.                , 0.                , 0.7854784971183756}};

        expect_tuple_close(linalg::qr(input, linalg::qrmode::reduced), reduced_q, reduced_r);

        // ---- mode 'r' ----
        // numpy: res_q4 (numpy returns R alone; only the second tuple member is compared here)
        xarray<double> r_only = {{-1.3152987216651169,-0.567877094797874 ,-1.0163710885529547},
                                 { 0.                , 1.2223138676385652, 0.7215655008695085},
                                 { 0.                , 0.                , 0.7854784971183756}};

        auto r_mode_result = linalg::qr(input, linalg::qrmode::r);
        EXPECT_TRUE(allclose(std::get<1>(r_mode_result), r_only));
    }

    /*py
    a = np.random.random((5, 10))
    res_q1 = np.linalg.qr(a, 'raw')
    res_q2 = np.linalg.qr(a, 'complete')
    res_q3 = np.linalg.qr(a, 'reduced')
    res_q4 = np.linalg.qr(a, 'r')
    */
    TEST(xtest_extended, qr2)
    {
        // Checks linalg::qr on a wide 5x10 matrix against the values that
        // numpy.linalg.qr produced for the same input in each of the four modes
        // ('raw', 'complete', 'reduced', 'r'). Every comparison uses allclose
        // with its default tolerances.

        // Expects both members of a two-element result tuple to match the given
        // arrays: std::get<0> against expected_first, std::get<1> against expected_second.
        auto expect_tuple_close = [](const auto& result,
                                     const xarray<double>& expected_first,
                                     const xarray<double>& expected_second)
        {
            EXPECT_TRUE(allclose(std::get<0>(result), expected_first));
            EXPECT_TRUE(allclose(std::get<1>(result), expected_second));
        };

        // numpy: a
        xarray<double> input = {{0.4319450186421158,0.2912291401980419,0.6118528947223795,
                                 0.1394938606520418,0.2921446485352182,0.3663618432936917,
                                 0.4560699842170359,0.7851759613930136,0.1996737821583597,
                                 0.5142344384136116},
                                {0.5924145688620425,0.0464504127199977,0.6075448519014384,
                                 0.1705241236872915,0.0650515929852795,0.9488855372533332,
                                 0.9656320330745594,0.8083973481164611,0.3046137691733707,
                                 0.0976721140063839},
                                {0.6842330265121569,0.4401524937396013,0.1220382348447788,
                                 0.4951769101112702,0.0343885211152184,0.9093204020787821,
                                 0.2587799816000169,0.662522284353982 ,0.311711076089411 ,
                                 0.5200680211778108},
                                {0.5467102793432796,0.184854455525527 ,0.9695846277645586,
                                 0.7751328233611146,0.9394989415641891,0.8948273504276488,
                                 0.5978999788110851,0.9218742350231168,0.0884925020519195,
                                 0.1959828624191452},
                                {0.0452272889105381,0.3253303307632643,0.388677289689482 ,
                                 0.2713490317738959,0.8287375091519293,0.3567533266935893,
                                 0.2809345096873808,0.5426960831582485,0.1409242249747626,
                                 0.8021969807540397}};

        // ---- mode 'raw' ----
        // numpy: res_q1[0]
        xarray<double> raw_h = {{-1.1430852952870696, 0.3761289948662397, 0.4344253062693247,
                                  0.3471109568548026, 0.0287151863113738},
                                {-0.4988738747365853, 0.4145384440977922,-0.1456730968857621,
                                  0.1343802288038163,-0.4549175132696516},
                                {-1.0982282164248067, 0.0432498341745755, 0.8009723247566577,
                                 -0.2697221220857602,-0.2118640849148783},
                                {-0.8189559577243967, 0.2159221672678357, 0.2467828455102148,
                                 -0.4358731022610104, 0.0126894274012747},
                                {-0.6468222288756241, 0.5399745339753013, 0.9011434603476536,
                                 -0.3516828694145329, 0.1205612964483228},
                                {-1.6166030169206462, 0.0627336303098124, 0.1745159258713335,
                                 -0.1676233275811677, 0.3369911999240203},
                                {-1.1247642047094615,-0.1631138338388988, 0.4469666475320985,
                                  0.229673631977487 , 0.3155802843315489},
                                {-1.5746170823854417, 0.2876936477590398, 0.5186696050660637,
                                  0.0972324032495854, 0.1124970816045023},
                                {-0.4678059691956431, 0.0924634343088705,-0.0398310260167535,
                                  0.1199094213119632, 0.1189824829973467},
                                {-0.6817147826175952, 0.8209704648352938, 0.1936105292921998,
                                  0.1556371881989978, 0.1610633542281174}};
        // numpy: res_q1[1]
        xarray<double> raw_tau = {1.3778764545594464,1.6048419481909388,1.7894907284949315,
                                  1.9996780087119976,0.                };

        expect_tuple_close(linalg::qr(input, linalg::qrmode::raw), raw_h, raw_tau);

        // ---- mode 'complete' ----
        // numpy: res_q2[0]
        xarray<double> complete_q = {{-0.3778764545594464, 0.2477850983490842, 0.2323946032026168,
                                       0.6442783657634198,-0.5715855718413764},
                                     {-0.5182592859033026,-0.5116427882060655, 0.0755411199296714,
                                       0.3718390470559857, 0.5706647283090043},
                                     {-0.5985844007732788, 0.3414264138444295,-0.6868036045376356,
                                      -0.2311050279755064, 0.0039992397751236},
                                     {-0.4782760145698324,-0.1296501056095487, 0.5617369601749853,
                                      -0.6259003290333343,-0.217125009365484 },
                                     {-0.0395659791067296, 0.737185903526121 , 0.3912015899373974,
                                       0.0384753772446496, 0.5481536630508244}};
        // numpy: res_q2[1]
        xarray<double> complete_r = {{-1.1430852952870696,-0.4988738747365853,-1.0982282164248067,
                                      -0.8189559577243967,-0.6468222288756241,-1.6166030169206462,
                                      -1.1247642047094615,-1.5746170823854417,-0.4678059691956431,
                                      -0.6817147826175952},
                                     { 0.                , 0.4145384440977922, 0.0432498341745755,
                                       0.2159221672678357, 0.5399745339753013, 0.0627336303098124,
                                      -0.1631138338388988, 0.2876936477590398, 0.0924634343088705,
                                       0.8209704648352938},
                                     { 0.                , 0.                , 0.8009723247566577,
                                       0.2467828455102148, 0.9011434603476536, 0.1745159258713335,
                                       0.4469666475320985, 0.5186696050660637,-0.0398310260167535,
                                       0.1936105292921998},
                                     { 0.                , 0.                , 0.                ,
                                      -0.4358731022610104,-0.3516828694145329,-0.1676233275811677,
                                       0.229673631977487 , 0.0972324032495854, 0.1199094213119632,
                                       0.1556371881989978},
                                     { 0.                , 0.                , 0.                ,
                                       0.                , 0.1205612964483228, 0.3369911999240203,
                                       0.3155802843315489, 0.1124970816045023, 0.1189824829973467,
                                       0.1610633542281174}};

        expect_tuple_close(linalg::qr(input, linalg::qrmode::complete), complete_q, complete_r);

        // ---- mode 'reduced' ----
        // numpy: res_q3[0]
        xarray<double> reduced_q = {{-0.3778764545594464, 0.2477850983490842, 0.2323946032026168,
                                      0.6442783657634198,-0.5715855718413764},
                                    {-0.5182592859033026,-0.5116427882060655, 0.0755411199296714,
                                      0.3718390470559857, 0.5706647283090043},
                                    {-0.5985844007732788, 0.3414264138444295,-0.6868036045376356,
                                     -0.2311050279755064, 0.0039992397751236},
                                    {-0.4782760145698324,-0.1296501056095487, 0.5617369601749853,
                                     -0.6259003290333343,-0.217125009365484 },
                                    {-0.0395659791067296, 0.737185903526121 , 0.3912015899373974,
                                      0.0384753772446496, 0.5481536630508244}};
        // numpy: res_q3[1]
        xarray<double> reduced_r = {{-1.1430852952870696,-0.4988738747365853,-1.0982282164248067,
                                     -0.8189559577243967,-0.6468222288756241,-1.6166030169206462,
                                     -1.1247642047094615,-1.5746170823854417,-0.4678059691956431,
                                     -0.6817147826175952},
                                    { 0.                , 0.4145384440977922, 0.0432498341745755,
                                      0.2159221672678357, 0.5399745339753013, 0.0627336303098124,
                                     -0.1631138338388988, 0.2876936477590398, 0.0924634343088705,
                                      0.8209704648352938},
                                    { 0.                , 0.                , 0.8009723247566577,
                                      0.2467828455102148, 0.9011434603476536, 0.1745159258713335,
                                      0.4469666475320985, 0.5186696050660637,-0.0398310260167535,
                                      0.1936105292921998},
                                    { 0.                , 0.                , 0.                ,
                                     -0.4358731022610104,-0.3516828694145329,-0.1676233275811677,
                                      0.229673631977487 , 0.0972324032495854, 0.1199094213119632,
                                      0.1556371881989978},
                                    { 0.                , 0.                , 0.                ,
                                      0.                , 0.1205612964483228, 0.3369911999240203,
                                      0.3155802843315489, 0.1124970816045023, 0.1189824829973467,
                                      0.1610633542281174}};

        expect_tuple_close(linalg::qr(input, linalg::qrmode::reduced), reduced_q, reduced_r);

        // ---- mode 'r' ----
        // numpy: res_q4 (numpy returns R alone; only the second tuple member is compared here)
        xarray<double> r_only = {{-1.1430852952870696,-0.4988738747365853,-1.0982282164248067,
                                  -0.8189559577243967,-0.6468222288756241,-1.6166030169206462,
                                  -1.1247642047094615,-1.5746170823854417,-0.4678059691956431,
                                  -0.6817147826175952},
                                 { 0.                , 0.4145384440977922, 0.0432498341745755,
                                   0.2159221672678357, 0.5399745339753013, 0.0627336303098124,
                                  -0.1631138338388988, 0.2876936477590398, 0.0924634343088705,
                                   0.8209704648352938},
                                 { 0.                , 0.                , 0.8009723247566577,
                                   0.2467828455102148, 0.9011434603476536, 0.1745159258713335,
                                   0.4469666475320985, 0.5186696050660637,-0.0398310260167535,
                                   0.1936105292921998},
                                 { 0.                , 0.                , 0.                ,
                                  -0.4358731022610104,-0.3516828694145329,-0.1676233275811677,
                                   0.229673631977487 , 0.0972324032495854, 0.1199094213119632,
                                   0.1556371881989978},
                                 { 0.                , 0.                , 0.                ,
                                   0.                , 0.1205612964483228, 0.3369911999240203,
                                   0.3155802843315489, 0.1124970816045023, 0.1189824829973467,
                                   0.1610633542281174}};

        auto r_mode_result = linalg::qr(input, linalg::qrmode::r);
        EXPECT_TRUE(allclose(std::get<1>(r_mode_result), r_only));
    }
}
